# sciq_mcq_evaluate_multi_lora_sequential.py
import os
import re
import math
import csv
import collections
import torch
import time
import warnings
import random
import hashlib
import multiprocessing
import gc
from typing import Tuple, List

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset

# -------------------------------
# Config (edit via env)
# -------------------------------
# Hugging Face access token. Unset or empty resolves to None so that
# `from_pretrained(token=None)` can fall back to a cached login or anonymous
# access; a placeholder string would be forwarded as if it were a real token.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
MODEL_NAME = os.environ.get("MODEL_NAME", "please enter MODEL_NAME")

# --- Full case LoRA dirs (three) ---
# Defaults starting with "to enter" are placeholders; the entry point skips
# any model whose directory is still a placeholder or does not exist.
TENSORS_DIR_FULL_MLPATT = os.environ.get("TENSORS_DIR_FULL_MLPATT", "to enter values of TENSORS_DIR_FULL_MLPATT")
TENSORS_DIR_FULL_ATT = os.environ.get("TENSORS_DIR_FULL_ATT", "to enter values of TENSORS_DIR_FULL_ATT")
TENSORS_DIR_FULL_MLP = os.environ.get("TENSORS_DIR_FULL_MLP", "to enter values of TENSORS_DIR_FULL_MLP")

# --- Top1 case LoRA dirs (three) ---
TENSORS_DIR_TOP1_MLPATT = os.environ.get("TENSORS_DIR_TOP1_MLPATT", "to enter values of TENSORS_DIR_TOP1_MLPATT")
TENSORS_DIR_TOP1_ATT = os.environ.get("TENSORS_DIR_TOP1_ATT", "to enter values of TENSORS_DIR_TOP1_ATT")
TENSORS_DIR_TOP1_MLP = os.environ.get("TENSORS_DIR_TOP1_MLP", "to enter values of TENSORS_DIR_TOP1_MLP")

# --- Top3 case LoRA dirs (three) ---
TENSORS_DIR_TOP3_MLPATT = os.environ.get("TENSORS_DIR_TOP3_MLPATT", "to enter values of TENSORS_DIR_TOP3_MLPATT")
TENSORS_DIR_TOP3_ATT = os.environ.get("TENSORS_DIR_TOP3_ATT", "to enter values of TENSORS_DIR_TOP3_ATT")
TENSORS_DIR_TOP3_MLP = os.environ.get("TENSORS_DIR_TOP3_MLP", "to enter values of TENSORS_DIR_TOP3_MLP")

# CSV / LOG locations (the defaults are placeholder names)
CSV_OUT = os.environ.get("CSV_OUT", "to enter values of CSV_OUT")
LOG_DIR = os.environ.get("LOG_DIR", "to enter values of LOG_DIR")

# -------------------------------
# Hyperparameters (required)
# -------------------------------
# These have no defaults: each one must be set in the environment as an integer.
def get_required_int_env(name: str) -> int:
    """Read environment variable *name* and return its value as an int.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The variable's value parsed with ``int()``.

    Raises:
        RuntimeError: If the variable is unset, or if its value cannot be
            parsed as an integer (the original ``ValueError`` is chained as
            the cause).
    """
    raw = os.environ.get(name)
    if raw is None:
        raise RuntimeError(f"Environment variable {name} is not set. Please set it to an integer.")
    try:
        return int(raw)
    except ValueError as err:
        # int() on a str only fails with ValueError; keep it as the cause so
        # the traceback shows exactly what could not be parsed.
        raise RuntimeError(f"Environment variable {name} must be an integer. Got: {raw!r}") from err

# Required hyperparameters, read from the environment in a fixed order. The
# first variable that is missing or not an integer aborts start-up with the
# RuntimeError raised by get_required_int_env.
N, NUM_K_SAMPLING, BATCH_SIZE, MAX_NEW_TOKENS, PREVIEW = (
    get_required_int_env(_env_name)
    for _env_name in ("N_SAMPLES", "NUM_K_SAMPLING", "BATCH_SIZE", "MAX_NEW_TOKENS", "PREVIEW")
)

# -------------------------------
# Multiprocessing and CUDA tweaks
# -------------------------------
# Worker count for the CPU-side pool; defaults to (CPU count - 2), at least 1.
PARALLEL_WORKERS = int(os.environ.get("PARALLEL_WORKERS", str(max(1, (os.cpu_count() or 4) - 2))))
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Ensure LOG_DIR exists and make CSV_OUT live inside it: a relative CSV_OUT is
# joined onto LOG_DIR, an absolute one keeps only its basename.
os.makedirs(LOG_DIR, exist_ok=True)
if not CSV_OUT:
    # An empty CSV_OUT used to reach os.makedirs("") below and crash at import
    # time; fall back to a fixed file name inside LOG_DIR instead.
    CSV_OUT = "sciq_accuracy_summary.csv"
if os.path.isabs(CSV_OUT):
    CSV_OUT = os.path.join(LOG_DIR, os.path.basename(CSV_OUT))
else:
    CSV_OUT = os.path.join(LOG_DIR, CSV_OUT)
# A relative CSV_OUT may contain sub-directories, so create its parent too.
os.makedirs(os.path.dirname(CSV_OUT), exist_ok=True)

# -------------------------------
# Tokenization caches (shared across models)
# -------------------------------
# Tokenizer outputs as returned by the tokenizer call, keyed by the prompt's
# SHA-256 hex digest (filled by get_tokenized_inputs).
token_cache_cpu: dict = {}
# Copies of those outputs moved to a device, keyed by (prompt digest, str(device)).
# Entries are never evicted, so they persist for the lifetime of the process.
device_input_cache: dict = {}

def _prompt_hash(s: str) -> str:
    """Return the hexadecimal SHA-256 digest of *s* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()

def _pin_if_cuda(tensor, device):
    """Return a pinned-memory copy of *tensor* when a CUDA transfer is coming.

    Pinning is attempted only if *device* is a CUDA device and *tensor* lives
    on the CPU. Any failure (including a *device* without a ``type``
    attribute) is ignored and the original tensor is returned unchanged.
    """
    result = tensor
    try:
        needs_pin = device.type == "cuda" and tensor.device.type == "cpu"
        if needs_pin:
            result = tensor.pin_memory()
    except Exception:
        # Best effort: fall back to the unpinned tensor.
        result = tensor
    return result

def get_tokenized_inputs(tok, prompt: str, device: torch.device):
    """Tokenize *prompt* once and return its tensors placed on *device*.

    Two module-level caches are consulted: ``device_input_cache`` (keyed by
    prompt digest and ``str(device)``) and ``token_cache_cpu`` (keyed by prompt
    digest only). The returned dict is the cached object itself, so callers
    should treat it as read-only.

    Side effect: if *tok* has no pad token when a prompt is first tokenized,
    its pad token is set to its EOS token.
    """
    digest = _prompt_hash(prompt)
    cache_key = (digest, str(device))

    cached = device_input_cache.get(cache_key)
    if cached is not None:
        return cached

    if digest not in token_cache_cpu:
        encoded = tok(prompt, return_tensors="pt")
        if getattr(tok, "pad_token", None) is None:
            tok.pad_token = tok.eos_token
        token_cache_cpu[digest] = encoded

    moved = {}
    for name, value in token_cache_cpu[digest].items():
        try:
            # Pin first (when applicable) so the copy can be non-blocking.
            moved[name] = _pin_if_cuda(value, device).to(device, non_blocking=True)
        except Exception:
            # Fall back to a plain transfer if the pinned path fails.
            moved[name] = value.to(device)

    device_input_cache[cache_key] = moved
    return moved

# -------------------------------
# Load pretrained base (bfloat16 + low_cpu_mem_usage)
# -------------------------------
def load_pretrained(model_name, hf_token):
    """Load the base causal LM and its tokenizer for inference.

    The model is requested in bfloat16 with ``device_map="auto"`` and
    ``low_cpu_mem_usage=True``. If the tokenizer has no pad token, its EOS
    token is used. The model config's pad/EOS ids are synced to the tokenizer
    on a best-effort basis.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    model_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    base = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, use_fast=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        base.config.pad_token_id = tokenizer.pad_token_id
        base.config.eos_token_id = tokenizer.eos_token_id
    except Exception:
        # Config sync is optional; generation passes pad_token_id explicitly.
        pass

    print("Loaded pretrained base model for inference.")
    base.eval()
    return base, tokenizer

# -------------------------------
# Load LoRA (keeps mapping/loading logic but minimal prints)
# -------------------------------
def _read_lora_tensor_files(tensors_dir):
    """Load every ``*.pt`` file under *tensors_dir* into a ``{file stem: object}`` dict."""
    loaded = {}
    for root, _, files in os.walk(tensors_dir):
        for fname in files:
            if not fname.endswith(".pt"):
                continue
            path = os.path.join(root, fname)
            try:
                # NOTE(review): torch.load unpickles the file; only point this
                # at trusted directories (or pass weights_only=True if the
                # installed torch version supports it).
                tensor = torch.load(path, map_location="cpu")
            except Exception as err:
                # Report instead of silently skipping: a missing tensor means
                # the adapter is evaluated with part of its weights untouched.
                warnings.warn(f"Could not load LoRA tensor file {path}: {err}")
                continue
            # Strip only the trailing extension; str.replace would also remove
            # any ".pt" occurring in the middle of the name.
            loaded[fname[: -len(".pt")]] = tensor
    return loaded


def _map_raw_to_peft_keys(state_raw, target_keys):
    """Resolve raw tensor names to PEFT state-dict keys.

    Each raw name is tried with the prefixes "", "base_model.", "model." and
    the suffixes "", ".weight", ".default.weight" (prefix-major order); the
    first candidate present in *target_keys* wins.

    Returns:
        ``(mapped, unmatched_raw)``: the resolved ``{peft key: tensor}`` dict
        and the list of raw names that matched nothing.
    """
    prefixes = ("", "base_model.", "model.")
    suffixes = ("", ".weight", ".default.weight")
    mapped = {}
    unmatched_raw = []
    for raw_k, tensor in state_raw.items():
        hit = next(
            (p + raw_k + s for p in prefixes for s in suffixes if p + raw_k + s in target_keys),
            None,
        )
        if hit is None:
            unmatched_raw.append(raw_k)
        else:
            mapped[hit] = tensor
    return mapped, unmatched_raw


def load_lora(base_model, hf_token, tensors_dir, r, alpha):
    """Load the base model, wrap it with a LoRA adapter and apply saved tensors.

    Args:
        base_model: Model name or path passed to ``from_pretrained``.
        hf_token: Token passed to ``from_pretrained``.
        tensors_dir: Directory walked recursively for ``*.pt`` files; each file
            stem is treated as a (possibly partial) PEFT state-dict key.
        r: LoRA rank for the ``LoraConfig``.
        alpha: ``lora_alpha`` for the ``LoraConfig``.

    Returns:
        ``(peft_model, tokenizer)`` with the model in eval mode. The
        tokenizer's pad token is set to its EOS token.

    Tensors that cannot be read or cannot be matched to a PEFT key are
    reported with ``warnings.warn``; loading continues without them.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token
    )
    tok = AutoTokenizer.from_pretrained(base_model, token=hf_token, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    try:
        model.config.pad_token_id = tok.pad_token_id
        model.config.eos_token_id = tok.eos_token_id
    except Exception:
        # Config sync is optional; generation passes pad_token_id explicitly.
        pass

    cfg = LoraConfig(
        r=r, lora_alpha=alpha,
        target_modules=["q_proj","k_proj","v_proj","o_proj","down_proj","up_proj","gate_proj"],
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, cfg)

    state_raw = {}
    if tensors_dir and os.path.exists(tensors_dir):
        state_raw = _read_lora_tensor_files(tensors_dir)
    elif tensors_dir:
        print(f"Warning: tensors_dir provided but path does not exist: {tensors_dir}")
    else:
        print("Warning: no tensors_dir provided; returning PEFT wrapper without additional weights.")

    # Map file stems onto the wrapped model's state-dict keys and apply them.
    target_keys = set(peft_model.state_dict().keys())
    mapped, unmatched_raw = _map_raw_to_peft_keys(state_raw, target_keys)
    peft_model.load_state_dict(mapped, strict=False)

    if unmatched_raw:
        preview = ", ".join(sorted(unmatched_raw)[:5])
        warnings.warn(
            f"{len(unmatched_raw)} tensor file(s) in {tensors_dir} matched no PEFT key "
            f"and were ignored (e.g. {preview})"
        )
    if tensors_dir:
        if mapped:
            print(f"Loaded LoRA weights from: {tensors_dir} ({len(mapped)} tensors applied)")
        else:
            # Do not claim success when nothing was applied: the adapter is
            # then still at its freshly initialised values.
            warnings.warn(f"No LoRA tensors from {tensors_dir} were applied to the model.")

    peft_model.eval()
    tok.pad_token = tok.eos_token
    return peft_model, tok

# -------------------------------
# Prompt builder for SciQ (MCQ)
# -------------------------------
def build_sciq_prompt(question, opts, support=None):
    """Build the multiple-choice prompt for one SciQ question.

    Options are lettered from "A" in the order given. ``support`` is accepted
    for call compatibility but is not included in the prompt.
    """
    option_lines = []
    for position, option in enumerate(opts):
        letter = chr(ord("A") + position)
        option_lines.append(f"{letter}. {option}")

    sections = [
        f"Question:\n{question.strip()}",
        "Options:\n" + "\n".join(option_lines),
        "Answer with the letter of the correct option only (A, B, C, or D). Do NOT provide any explanation.\n"
        "Answer:",
    ]
    return "\n\n".join(sections)

# -------------------------------
# Generation helpers (inference_mode + token cache usage)
# -------------------------------
def generate_greedy(model, tok, prompt, max_new_tokens=MAX_NEW_TOKENS):
    """Greedy-decode a continuation of *prompt* and return it as stripped text.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tok: Tokenizer used to encode the prompt and decode the output.
        prompt: Prompt text.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped).
    """
    device = next(model.parameters()).device
    inputs = get_tokenized_inputs(tok, prompt, device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Sampling knobs are unused when do_sample=False. Pin them to
            # their neutral values: temperature=0.0 is not a valid sampling
            # temperature and is flagged by transformers' config validation.
            temperature=1.0,
            top_p=1.0,
            pad_token_id=tok.eos_token_id,
        )
    # generate() returns prompt + continuation; keep the continuation only.
    prompt_len = inputs["input_ids"].shape[-1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

def generate_samples(model, tok, prompt, num_samples=8, batch_size=2, max_new_tokens=256, temperature=0.7, top_p=0.90):
    """Sample continuations of *prompt* and return the first non-blank line of each.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tok: Tokenizer used to encode the prompt and decode the outputs.
        prompt: Prompt text.
        num_samples: Total number of samples requested.
        batch_size: Maximum ``num_return_sequences`` per ``generate`` call.
        max_new_tokens: Upper bound on generated tokens per sample.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.

    Returns:
        A list with one string per returned sequence: the first non-blank line
        of the decoded continuation, or "" when the continuation is blank.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # 0 used to surface as ZeroDivisionError; a negative value silently
        # produced no candidates at all.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = next(model.parameters()).device
    inputs = get_tokenized_inputs(tok, prompt, device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Identical for every batch, so build it once.
    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.1,
    )

    candidates = []
    produced = 0
    with torch.inference_mode():
        for _ in range(math.ceil(num_samples / batch_size)):
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                num_return_sequences=this_bs,
                generation_config=gen_conf,
                pad_token_id=tok.eos_token_id,
            )
            for seq in outs:
                cont = tok.decode(seq[prompt_len:], skip_special_tokens=True).strip()
                lines = [ln for ln in cont.splitlines() if ln.strip() != ""]
                candidates.append(lines[0].strip() if lines else "")
            produced += this_bs
    return candidates

# -------------------------------
# Output parsing for MCQ
# -------------------------------
def normalize_text(s):
    """Lower-case *s*, trim it and collapse whitespace runs to one space.

    ``None`` (or any falsy value) is treated as the empty string.
    """
    return re.sub(r"\s+", " ", (s or "").strip().lower())

def extract_choice_from_text(text, options):
    """Map free-form model output to an option letter.

    Resolution order:
      1. The first standalone letter a-d / A-D found in *text*.
      2. The first option whose normalized text contains, or is contained in,
         the normalized output.
      3. The first option whose first word equals the output's first token.

    Args:
        text: Raw model output, or None.
        options: Option strings in display order (index 0 is "A").

    Returns:
        The upper-case letter, or None when *text* is None, blank, or matches
        nothing.
    """
    if text is None:
        return None
    t = str(text).strip()
    # NOTE(review): a standalone lower-case "a" (e.g. the article) also
    # matches here and is read as choice A.
    m = re.search(r"\b([A-Da-d])\b", t)
    if m:
        return m.group(1).upper()

    t_norm = normalize_text(t)
    if not t_norm:
        # "" is a substring of every option, so without this guard a blank
        # generation would always be scored as the first non-empty option.
        return None

    # The output's tokens do not depend on the option, so compute them once.
    toks = re.findall(r"[A-Za-z0-9_.+-]+", t_norm)
    first_tok = toks[0] if toks else None

    for i, opt in enumerate(options):
        opt_norm = normalize_text(opt)
        if not opt_norm:
            continue
        if opt_norm in t_norm or t_norm in opt_norm:
            return chr(65 + i)
        if first_tok is not None and opt_norm.split()[0] == first_tok:
            return chr(65 + i)
    return None

def choose_from_parsed(parsed_list, raw_list):
    """Pick the final answer letter by majority vote over parsed candidates.

    Args:
        parsed_list: Parsed letters, with None for candidates that could not
            be parsed.
        raw_list: Raw candidate strings, used only when no candidate parsed.

    Returns:
        The most frequent non-None letter (ties go to the letter seen first);
        otherwise the first standalone a-d / A-D letter found in *raw_list*,
        upper-cased; otherwise None.
    """
    votes = collections.Counter(p for p in parsed_list if p is not None)
    if votes:
        return votes.most_common(1)[0][0]
    # No votes means every parsed entry is None, so the former "first non-None
    # parsed value" scan could never return anything; go straight to raw text.
    for r in raw_list:
        m = re.search(r"\b([A-Da-d])\b", str(r))
        if m:
            return m.group(1).upper()
    return None

def _parse_and_choose(worker_args) -> Tuple[List[str], str]:
    """Parse each candidate into a letter and pick the final answer.

    *worker_args* is a ``(candidates, options)`` pair. Returns
    ``(parsed_letters, chosen_letter)``; entries of the first and the second
    element may be None when nothing could be parsed.
    """
    candidates, choices = worker_args
    letters = []
    for candidate in candidates:
        letters.append(extract_choice_from_text(candidate, choices))
    return letters, choose_from_parsed(letters, candidates)

def _build_per_sample_text(worker_args) -> str:
    """Render the detailed log entry for one evaluated sample.

    *worker_args* is the 11-tuple ``(index, total, question, options,
    gold_letter, gold_raw_text, candidates, parsed_letters, chosen_letter,
    is_correct, model_label)``. The result is a multi-line string framed by
    80-character "=" rules.
    """
    (idx, n_total, q_text, opts, gold, gold_raw,
     raw_cands, letters, chosen, was_right, label) = worker_args

    rule = "=" * 80
    option_rows = [f" {chr(65+k)}. {o}" for k, o in enumerate(opts)]

    if gold:
        gold_row = f"Gold: {gold} -> {opts[ord(gold)-65]}"
    else:
        gold_row = f"Gold (raw text match failed): {gold_raw}"

    rows = [
        "\n" + rule,
        f"[{idx+1}/{n_total}] Question:",
        q_text,
        "Options (A-D):",
        *option_rows,
        gold_row,
        f"\n--- {label} ---",
        f"Raw candidates: {raw_cands!s}",
        f"Parsed letters: {letters!s} Chosen: {chosen!s} | Correct? {was_right!s}",
        rule,
    ]
    return "\n".join(rows)

# -------------------------------
# Evaluate single loaded model with tqdm progress
# -------------------------------
def _shuffled_options_with_gold(ex, seed):
    """Return ``(options, gold_letter, correct_raw)`` for one SciQ record.

    The three distractors and the correct answer are shuffled with a
    ``random.Random(seed)`` instance, so the order depends only on *seed*.
    ``gold_letter`` is the letter of the first option whose normalized text
    equals the normalized correct answer, falling back to the first option
    containing it; it is None when neither is found.
    """
    correct_raw = ex.get("correct_answer", "") or ""
    options = [ex.get(f"distractor{n}", "") or "" for n in (1, 2, 3)] + [correct_raw]
    random.Random(seed).shuffle(options)

    corr_norm = normalize_text(correct_raw)
    normalized = [normalize_text(opt) for opt in options]
    gold_idx = next((k for k, o in enumerate(normalized) if o == corr_norm), None)
    if gold_idx is None:
        gold_idx = next((k for k, o in enumerate(normalized) if corr_norm and corr_norm in o), None)
    gold_letter = None if gold_idx is None else chr(65 + gold_idx)
    return options, gold_letter, correct_raw


def evaluate_loaded_model_on_sciq(model, tok, ds, model_label, log_fh=None, pool=None):
    """Evaluate one loaded model on *ds* and return the number of correct answers.

    Args:
        model: Loaded model passed to the generation helpers.
        tok: Matching tokenizer.
        ds: Sized iterable of dict-like SciQ records (``question``,
            ``distractor1..3``, ``correct_answer``, optional ``support``).
        model_label: Label used for the progress bar and log output.
        log_fh: Optional writable text handle for the detailed log.
        pool: Optional multiprocessing pool used to build the log text.

    Returns:
        Count of samples whose chosen letter equals the gold letter. Samples
        without a resolvable gold letter are never counted as correct.

    Uses sampling with majority vote when ``NUM_K_SAMPLING > 1``, otherwise a
    single greedy generation. Options are shuffled with the sample index as
    seed, so every model sees the same option order.
    """
    total = len(ds)
    correct = 0
    pbar = tqdm(total=total, desc=f"{model_label}", unit="it")
    try:
        for i, ex in enumerate(ds):
            # Guard against a None question the same way the other fields are.
            question = (ex.get("question") or "").strip()
            options, gold_letter, correct_raw = _shuffled_options_with_gold(ex, i)
            prompt = build_sciq_prompt(question, options, support=ex.get("support", None))

            if NUM_K_SAMPLING > 1:
                cands = generate_samples(model, tok, prompt, num_samples=NUM_K_SAMPLING, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS)
            else:
                out = generate_greedy(model, tok, prompt, max_new_tokens=MAX_NEW_TOKENS)
                lines = [ln for ln in out.splitlines() if ln.strip() != ""]
                cands = [lines[0].strip() if lines else ""]

            parsed, pred = _parse_and_choose((cands, options))
            is_corr = (gold_letter is not None and pred == gold_letter)
            correct += int(is_corr)

            # Only index 0 satisfies this for i in [0, total): one detailed
            # entry per model goes to the log; other previews are one-liners.
            should_log = (i % max(1, total)) == 0
            if should_log:
                build_args = (i, total, question, options, gold_letter, correct_raw, cands, parsed, pred, is_corr, model_label)
                if pool is not None:
                    per_sample_text = pool.apply(_build_per_sample_text, args=(build_args,))
                else:
                    per_sample_text = _build_per_sample_text(build_args)
                if i < PREVIEW:
                    print(per_sample_text)
                if log_fh:
                    try:
                        log_fh.write(per_sample_text + "\n")
                        log_fh.flush()
                    except (OSError, ValueError) as err:
                        # Logging stays best-effort, but say that it failed.
                        warnings.warn(f"Could not write detailed log entry for {model_label}: {err}")
            elif i < PREVIEW:
                snippet = question.splitlines()[0][:200] if question else ""
                print(f"[{i+1}/{total}] {snippet} | {model_label} pred:{pred} | gold:{gold_letter}")

            pbar.update(1)
    finally:
        pbar.close()
    return correct

# -------------------------------
# Entrypoint
# -------------------------------
if __name__ == "__main__":
    # Script entry point: evaluate the base model and each LoRA variant one
    # after another on the first N SciQ test items, then write a console
    # summary, a CSV row and a PNG table.
    dataset_name_for_log = "SciQ"
    os.makedirs(LOG_DIR, exist_ok=True)

    ds = load_dataset("allenai/sciq", split=f"test[:{N}]")
    total = len(ds)
    print(f"Loaded {total} SciQ items (using test[:{N}])")

    log_file_path = os.path.join(LOG_DIR, f"{dataset_name_for_log}_detailed_log.txt")
    try:
        log_fh = open(log_file_path, "w", encoding="utf-8")
    except OSError as e:
        print("Could not open log file for writing:", e)
        log_fh = None

    pool = None
    try:
        if PARALLEL_WORKERS > 1:
            pool = multiprocessing.Pool(processes=PARALLEL_WORKERS)
            print(f"Started multiprocessing pool with {PARALLEL_WORKERS} workers for CPU parsing/logging.")
    except Exception as e:
        pool = None
        print("Could not start multiprocessing pool; proceeding single-threaded for CPU work:", e)

    # 10 models: pretrained + 3 full + 3 top1 + 3 top3 (order matters: it is
    # also the column order of the CSV and PNG outputs).
    model_list = [
        ("Pretrained", None),  # None indicates pretrained base (no LoRA dir)
        ("Full_MLPATT", TENSORS_DIR_FULL_MLPATT),
        ("Full_ATT", TENSORS_DIR_FULL_ATT),
        ("Full_MLP", TENSORS_DIR_FULL_MLP),
        ("Top1_MLPATT", TENSORS_DIR_TOP1_MLPATT),
        ("Top1_ATT", TENSORS_DIR_TOP1_ATT),
        ("Top1_MLP", TENSORS_DIR_TOP1_MLP),
        ("Top3_MLPATT", TENSORS_DIR_TOP3_MLPATT),
        ("Top3_ATT", TENSORS_DIR_TOP3_ATT),
        ("Top3_MLP", TENSORS_DIR_TOP3_MLP),
    ]
    model_labels = [name for name, _ in model_list]

    results_correct = {}
    results_acc = {}

    # try/finally so the pool and the log file are released even when a model
    # fails to load and the exception propagates.
    try:
        for label, dirpath in model_list:
            if label == "Pretrained":
                # load pretrained base
                print(f"\n--- Loading & evaluating: {label} ---\n")
                model_loaded, tok_loaded = load_pretrained(MODEL_NAME, HF_TOKEN)
            else:
                if not dirpath or dirpath.startswith("to enter") or not os.path.exists(dirpath):
                    # NOTE: skipped models are recorded as 0 / 0.0, which is
                    # indistinguishable from a real zero score in the outputs.
                    print(f"Skipping {label}: tensors dir missing or placeholder ({dirpath}).")
                    results_correct[label] = 0
                    results_acc[label] = 0.0
                    continue
                print(f"\n--- Loading & evaluating: {label} from {dirpath} ---\n")
                model_loaded, tok_loaded = load_lora(MODEL_NAME, HF_TOKEN, dirpath, r=16, alpha=32)

            try:
                correct = evaluate_loaded_model_on_sciq(model_loaded, tok_loaded, ds, label, log_fh=log_fh, pool=pool)
                acc = correct / total if total > 0 else 0.0
                results_correct[label] = correct
                results_acc[label] = acc
            except Exception as e:
                print(f"Exception while evaluating {label}: {e}")
                results_correct[label] = 0
                results_acc[label] = 0.0
            finally:
                # Unload: drop the references, collect garbage (including
                # reference cycles) and only then ask the CUDA allocator to
                # release its cache; the reverse order leaves memory held by
                # not-yet-collected objects in place.
                del model_loaded
                del tok_loaded
                gc.collect()
                try:
                    torch.cuda.empty_cache()
                except RuntimeError as e:
                    warnings.warn(f"Could not empty the CUDA cache after {label}: {e}")
                print(f"Unloaded {label} and freed GPU memory.\n")
    finally:
        if pool:
            pool.close()
            pool.join()
        if log_fh:
            log_fh.close()

    # final summary print
    print("\n\nFINAL AGGREGATED ACCURACY (ALL MODELS) — Dataset: {}  N = {}\n".format(dataset_name_for_log, total))
    for label in model_labels:
        correct = results_correct.get(label, 0)
        acc = results_acc.get(label, 0.0)
        print(f"{label}: {correct} / {total} = {acc:.3f}")

    # write CSV summary (10 model columns, derived from model_list)
    try:
        # Also write the header when the file exists but is still empty.
        write_header = (not os.path.exists(CSV_OUT)) or os.path.getsize(CSV_OUT) == 0
        with open(CSV_OUT, "a", newline="", encoding="utf-8") as fh:
            writer = csv.writer(fh)
            if write_header:
                writer.writerow(
                    ["dataset", "N"]
                    + [f"{name.lower()}_correct" for name in model_labels]
                    + [f"acc_{name.lower()}" for name in model_labels]
                    + ["timestamp"]
                )
            writer.writerow(
                [dataset_name_for_log, total]
                + [results_correct.get(name, 0) for name in model_labels]
                + [f"{results_acc.get(name, 0.0):.4f}" for name in model_labels]
                + [time.strftime("%Y-%m-%d %H:%M:%S")]
            )
        print(f"Wrote CSV summary to: {CSV_OUT}")
    except Exception as e:
        warnings.warn(f"Could not write CSV summary to {CSV_OUT}: {e}")

    # PNG table (10 columns)
    try:
        import matplotlib.pyplot as plt
        # "Full_MLPATT" -> "Full MLP+ATT", "Top1_ATT" -> "Top1 ATT", ...
        headers = [name.replace("_", " ").replace("MLPATT", "MLP+ATT") for name in model_labels]
        vals = [f"{results_acc.get(name, 0.0):.4f}" for name in model_labels]
        fig, ax = plt.subplots(figsize=(12, 2.8))
        ax.axis('off')
        # build table rows: header row then values row
        table = ax.table(cellText=[headers, vals], loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2.0)
        png_path = os.path.join(LOG_DIR, f"{dataset_name_for_log}_accuracy_table.png")
        plt.tight_layout()
        plt.savefig(png_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print("Saved accuracy PNG table to:", png_path)
    except Exception as e:
        warnings.warn(f"Could not create/save PNG accuracy table: {e}")
